import torch
import torch.nn as nn
import networkx as nx
import numpy as np
from transformers import HieraModel
import math
from tensorboardX import SummaryWriter  # 导入 tensorboardX
from sklearn.preprocessing import StandardScaler

from find_subgraph import find_power_of_two_connected_subgraphs_optimized


class InferenceTimePredictor(nn.Module):
    """MLP regressor mapping node-set/workload features to predictions.

    Built as Linear+ReLU pairs, one per entry in ``hidden_sizes``,
    followed by a final Linear projection to ``output_size``.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super(InferenceTimePredictor, self).__init__()
        stack = []
        prev = input_size
        for width in hidden_sizes:
            stack += [nn.Linear(prev, width), nn.ReLU()]
            prev = width
        stack.append(nn.Linear(prev, output_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run the feature batch through the MLP."""
        return self.model(x)


def prepare_input(G, A, h, w):
    """Build the 5-feature input vector for a candidate node set ``A``.

    Features: [min compute power over A, min pairwise bandwidth over A,
    input height h, input width w, |A|].  Missing node attributes and
    missing edges fall back to 0.

    Returns:
        (feature tensor of shape (5,), total rental cost of A)
    """
    powers = [G.nodes[n].get('compute_power', 0) for n in A]
    rents = [G.nodes[n].get('rental_cost', 0) for n in A]

    per_node_min_bw = []
    for n in A:
        bws = [
            G.get_edge_data(n, m, default={'bandwidth': 0})['bandwidth']
            for m in A
            if n != m
        ]
        # A singleton set has no peers; its bandwidth floor is 0.
        per_node_min_bw.append(min(bws) if bws else 0)

    # |A| is included as a feature so the model can learn scaling effects.
    features = [min(powers), min(per_node_min_bw), h, w, len(A)]
    return torch.tensor(features, dtype=torch.float32), sum(rents)


def train_model(model, train_loader, criterion, optimizer, epochs, save_path):
    """Train ``model`` on ``train_loader`` for ``epochs`` epochs.

    Logs the mean per-epoch loss to TensorBoard and saves the final
    weights to ``save_path``.
    """
    model.train()
    writer = SummaryWriter()  # TensorBoard logger (tensorboardX)
    for epoch in range(epochs):
        total = 0.0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            pred = model(batch_x)
            # Outputs and labels share shape (batch, 2): [time, cost].
            loss = criterion(pred, batch_y)
            loss.backward()
            optimizer.step()
            total += loss.item()
        mean_loss = total / len(train_loader)
        writer.add_scalar('Training Loss', mean_loss, epoch)
        print(f'Epoch {epoch + 1}, Loss: {mean_loss}')

    # Persist the trained weights for later use by test().
    torch.save(model.state_dict(), save_path)
    writer.close()


def infer(model, G, A, h, w, scaler):
    """Predict inference time for node set ``A`` and derive its cost.

    The cost is predicted_time * total rental cost of A; the model's
    second output channel is not used here.
    """
    model.eval()
    features, total_rent = prepare_input(G, A, h, w)
    # Standardize with the same scaler that was fit on the training set.
    scaled = scaler.transform([features.numpy()])
    batch = torch.tensor(scaled, dtype=torch.float32)
    with torch.no_grad():
        time_pred = model(batch)[0, 0].item()
        cost_pred = time_pred * total_rent
    return time_pred, cost_pred


def test(G, model_path, h, w, scaler):
    """Score every power-of-two connected subgraph of ``G``.

    Returns (best_time_set, best_cost_set, best_time, best_cost): the
    node sets with the lowest predicted time and lowest predicted cost.
    """
    # Rebuild the predictor with the same architecture used in training
    # (5 input features because set size was added as a feature).
    model = InferenceTimePredictor(5, [20, 15], 2)
    model.load_state_dict(torch.load(model_path))

    best_time = float('inf')
    best_cost = float('inf')
    best_time_set = None
    best_cost_set = None

    # Enumerate all candidate node sets and keep the two optima.
    for sub in find_power_of_two_connected_subgraphs_optimized(G):
        A = set(sub.nodes())
        t, c = infer(model, G, A, h, w, scaler)
        if t < best_time:
            best_time, best_time_set = t, A
        if c < best_cost:
            best_cost, best_cost_set = c, A

    return best_time_set, best_cost_set, best_time, best_cost


if __name__ == '__main__':
    # Toy end-to-end example: build a small cluster graph, fit the
    # predictor on hand-made samples, then search for the best node sets.
    input_size = 5   # min power, min bandwidth, h, w, |A|
    hidden_sizes = [20, 15]
    output_size = 2  # [predicted time, predicted cost]

    model = InferenceTimePredictor(input_size, hidden_sizes, output_size)

    # Example cluster: node attrs are (compute power, rental cost),
    # edge attr is link bandwidth.
    G = nx.Graph()
    node_specs = {
        1: (100, 8), 2: (200, 20), 3: (150, 15),
        4: (450, 100), 5: (150, 15), 6: (150, 15),
    }
    for nid, (power, rent) in node_specs.items():
        G.add_node(nid, compute_power=power, rental_cost=rent)
    edge_specs = [
        (1, 2, 100), (2, 3, 200), (2, 4, 200), (1, 3, 150),
        (4, 5, 100), (5, 6, 200), (4, 6, 150),
    ]
    for u, v, bw in edge_specs:
        G.add_edge(u, v, bandwidth=bw)

    # Hand-crafted training samples — a real deployment would use a
    # large set of measured inference times instead.
    subs = [{5}, {5}, {5}, {5, 6}, {5, 6}, {5, 6}, {2, 4, 5, 6}, {2, 4, 5, 6}, {2, 4, 5, 6}, {1}, {2}, {4}]
    H = [256, 512, 1024, 256, 512, 1024, 256, 512, 1024, 1024, 1024, 1024]
    W = [256, 512, 1024, 256, 512, 1024, 256, 512, 1024, 1024, 1024, 1024]
    predicted_times = [5, 9, 16, 2.94, 5.29, 9.41, 1.66, 3, 5.33, 24, 12, 5.33]

    train_inputs = []
    train_labels = []
    for A, h, w, t in zip(subs, H, W, predicted_times):
        feats, rent_sum = prepare_input(G, A, h, w)
        train_inputs.append(feats)
        # Label: simulated [time, cost]; cost = time * summed rental cost.
        train_labels.append(torch.tensor([t, t * rent_sum], dtype=torch.float32))

    # Standardize the inputs; the fitted scaler is reused at test time.
    train_inputs = torch.stack(train_inputs)
    scaler = StandardScaler()
    train_inputs = torch.tensor(
        scaler.fit_transform(train_inputs.numpy()), dtype=torch.float32)

    train_labels = torch.stack(train_labels)
    dataset = torch.utils.data.TensorDataset(train_inputs, train_labels)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

    # Train and persist the model.
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    epochs = 500
    save_path = 'model_weights.pth'
    train_model(model, loader, criterion, optimizer, epochs, save_path)

    # Search for the best node sets at a 512x512 workload.
    h = 512
    w = 512
    min_time_set, min_cost_set, min_time, min_cost = test(G, save_path, h, w, scaler)
    print(f"推理时间最短的集合: {min_time_set, min_time}")
    print(f"成本最低的集合: {min_cost_set, min_cost}")